
import cv
import numpy as np
import matplotlib.pyplot as plt

# Count the occurrences of each gray level
def GetHistogram(image, width, height, depth):
    """Count how many scanned pixels take each gray value.

    Args:
        image: 2-D indexable image; ``image[row, col]`` must yield a value
            that ``int()`` accepts. That integer is used directly as a bin
            index, so it is expected to lie in ``[0, 2**depth)``.
        width: number of columns (starting at 0) to scan.
        height: number of rows (starting at 0) to scan.
        depth: bits per pixel; the histogram has ``2**depth`` bins.

    Returns:
        A list of ``2**depth`` ints where entry ``g`` is the number of
        scanned pixels whose value is ``g``.
    """
    counts = [0 for _ in range(2 ** depth)]
    for row in range(height):
        for col in range(width):
            counts[int(image[row, col])] += 1
    return counts


# Calculate the new intensity value for each gray level
def EqualizeHistorgram(originalHist, width, height, depth):
    """Build the gray-level mapping table used for histogram equalization.

    Input level ``k`` is mapped to
    ``round((2**depth - 1) * cdf(k) / (width * height))``, where ``cdf(k)``
    is the running sum of ``originalHist[0..k]``.

    Args:
        originalHist: per-level pixel counts; entries ``0 .. 2**depth - 1``
            are read.
        width, height: image size; their product is the total pixel count
            used to normalize the cumulative sum.
        depth: bits per pixel.

    Returns:
        A list of ``2**depth`` values; entry ``k`` is the new intensity for
        input level ``k``.
    """
    levels = 2 ** depth
    pixelCount = width * height
    top = levels - 1  # highest representable intensity
    mapping = []
    running = 0
    for level in range(levels):
        running += originalHist[level]
        mapping.append(round(top * running * 1.0 / pixelCount))
    return mapping


# Apply the equalization mapping table to the original image
def EqualizeImage(originalImage, histMappingTable):
    """Return a new single-channel image with every pixel remapped.

    Args:
        originalImage: legacy ``cv`` image; its ``width``, ``height`` and
            ``depth`` attributes size the output, and each pixel value
            (converted with ``int()``) is used as an index into
            ``histMappingTable``.
        histMappingTable: lookup table, e.g. the list returned by
            ``EqualizeHistorgram``.

    Returns:
        A newly created image of the same size and depth, with 1 channel,
        whose pixel ``[row, col]`` is ``histMappingTable[original pixel]``.
    """
    w = originalImage.width
    h = originalImage.height
    equalized = cv.CreateImage((w, h), originalImage.depth, 1)
    for row in range(originalImage.height):
        for col in range(originalImage.width):
            equalized[row, col] = histMappingTable[int(originalImage[row, col])]
    return equalized

def ShowHistogramDiagram(histogram, depth, windowsTitle):
    """Draw a gray-level histogram as a bar chart in a new matplotlib figure.

    Args:
        histogram: per-level pixel counts, e.g. the list returned by
            ``GetHistogram``; only the first ``2**depth`` entries are drawn.
        depth: bits per pixel; the x-axis spans ``[0, 2**depth)``.
        windowsTitle: title shown above the chart.

    The figure is only created here, not displayed; call ``plt.show()``
    afterwards to open it.
    """
    grayLevel = 2**depth
    counts = histogram[:grayLevel]
    plt.figure()
    # width=1.0 makes adjacent gray levels touch, like a classic histogram.
    plt.bar(range(len(counts)), counts, width=1.0, color='g')
    plt.xlim(0, grayLevel)
    plt.title(windowsTitle)
    plt.grid(True)


if __name__ == '__main__':
    # Demo: equalize ./img/Fig2.jpg and compare the histograms before and after.
    rawImg = cv.LoadImage('./img/Fig2.jpg', cv.CV_LOAD_IMAGE_GRAYSCALE)
    cv.ShowImage('Original Image', rawImg)

    # Bits per pixel of the loaded image; the histograms have 2**depth bins.
    depth = int(rawImg.depth)
    width, height = rawImg.width, rawImg.height

    rawHist = GetHistogram(rawImg, width, height, depth)
    histMapping = EqualizeHistorgram(rawHist, width, height, depth)

    resultImg = EqualizeImage(rawImg, histMapping)
    cv.ShowImage('Result Image', resultImg)
    resultHist = GetHistogram(resultImg, width, height, depth)

    # Every output level collects the counts of all input levels mapped onto
    # it, so resultHist can peak higher than rawHist; size the shared y-axis
    # from both so neither plot is clipped and the two stay comparable.
    yMax = round(max(max(rawHist), max(resultHist)) * 1.3)
    for title, hist in (('Original Image Histogram', rawHist),
                        ('Result Image Histogram', resultHist)):
        plt.figure()
        plt.plot(hist)
        plt.title(title)
        plt.axis([0, 2**depth, 0, yMax])
        plt.grid(True)

    # plt.show() blocks until the plot windows are closed; the legacy cv API
    # then waits for a key press in one of the OpenCV windows.
    plt.show()
    cv.WaitKey(0)